library(tidyverse)
library(lme4)
library(lmerTest)
library(logging)
library(mvtnorm)
library(mgcv)
# Total Gaussian log-likelihood of held-out observations under a fitted model.
#
# The mean of each observation is the model's prediction for that row
# (re.form = NA asks lme4 models for fixed-effects-only predictions), and the
# standard deviation is the residual sigma estimated from the training data.
#
# lm:     fitted model supporting predict() and sigma().
# test_X: data frame of predictors for the held-out rows.
# test_y: observed responses, aligned with the rows of test_X.
# Returns a single number: the sum of per-row log densities.
logLik_test <- function(lm, test_X, test_y) {
  residual_sd <- sigma(lm)
  mu <- predict(lm, test_X, re.form = NA)
  per_row <- dnorm(test_y, mean = mu, sd = residual_sd, log = TRUE)
  sum(per_row)
}
# Per-observation Gaussian log-likelihood of held-out data under a fitted model.
#
# Same model as logLik_test() (prediction as mean, training residual sigma as
# standard deviation), but returns one log density per row instead of the sum.
#
# lm:     fitted model supporting predict() and sigma().
# test_X: data frame of predictors for the held-out rows.
# test_y: observed responses, aligned with the rows of test_X.
# Returns a numeric vector with one entry per element of test_y.
logLik_test_per <- function(lm, test_X, test_y) {
  fitted_means <- predict(lm, test_X, re.form = NA)
  dnorm(x = test_y, mean = fitted_means, sd = sigma(lm), log = TRUE)
}
# Mean squared prediction error of a fitted model on held-out data.
#
# lm:     fitted model supporting predict() (re.form = NA requests
#         fixed-effects-only predictions from lme4 models).
# test_X: data frame of predictors for the held-out rows.
# test_y: observed responses, aligned with the rows of test_X.
# Returns a single number.
mse_test <- function(lm, test_X, test_y) {
  errors <- test_y - predict(lm, test_X, re.form = NA)
  mean(errors^2)
}
#Sanity checks
#mylm <- gam(psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr"), data=train_data)
#c(logLik(mylm), logLik_test(mylm, train_data, train_data$psychometric))
#logLik_test(mylm, test_data, test_data$psychometric)
# Load the harmonized per-token results and attach lagged (spillover) features
# for up to four preceding tokens. Lags are computed separately for every
# corpus--model--training--seed configuration.
data <- read.csv("../data/harmonized_results.csv")
all_data <- data %>%
  mutate(seed = as.factor(seed)) %>%
  group_by(corpus, model, training, seed) %>%
  # lag() follows the existing row order inside each group; the contiguity
  # filter below discards rows whose lagged tokens are not the immediately
  # preceding codes.
  mutate(prev_surp = lag(surprisal),
         prev_code = lag(code),
         prev_len = lag(len),
         prev_freq = lag(freq),
         prev2_freq = lag(prev_freq),
         prev2_code = lag(prev_code),
         prev2_len = lag(prev_len),
         prev2_surp = lag(prev_surp),
         prev3_freq = lag(prev2_freq),
         prev3_code = lag(prev2_code),
         prev3_len = lag(prev2_len),
         prev3_surp = lag(prev2_surp),
         prev4_freq = lag(prev3_freq),
         prev4_code = lag(prev3_code),
         prev4_len = lag(prev3_len),
         prev4_surp = lag(prev3_surp)) %>%
  ungroup() %>%
  # Require a contiguous token history: 2 tokens back for the dundee corpus,
  # 4 tokens back for all other corpora.
  # NB this removes most zero-surprisal rows, since early-sentence tokens don't
  # have contiguous token history (a few remain and are dropped later).
  filter((corpus == "dundee" & code == prev2_code + 2) |
           (corpus != "dundee" & code == prev4_code + 4)) %>%
  # The lagged codes were only needed for the contiguity check above.
  select(-prev_code, -prev2_code, -prev3_code, -prev4_code) %>%
  drop_na()
# Rename the "gpt-2" model label to "gpt2" and store model as a factor again.
all_data <- all_data %>%
  mutate(model = as.character(model)) %>%
  mutate(model = as.factor(if_else(model == "gpt-2", "gpt2", model)))
# Diagnostic: list the (corpus, code) token x (model, training, seed)
# combinations that are absent from all_data, i.e. tokens that some
# configuration has no row for (never scored, or filtered out above).
missing_rows <- all_data %>%
  complete(nesting(corpus, code), nesting(model, training, seed)) %>%
  group_by(corpus, code) %>%
  filter(sum(is.na(surprisal)) > 0) %>%
  ungroup() %>%
  anti_join(all_data, by = c("corpus", "code", "model", "training", "seed"))
# Explicit print() so the plot and table appear when the script is run via
# source()/Rscript, not only at an interactive prompt.
print(ggplot(missing_rows, aes(x = corpus, fill = factor(paste(model, training)))) +
        geom_bar(position = position_dodge(width = 0.8)))
# Count missing rows per configuration and sort (descending) *before* printing.
print(missing_rows %>% count(model, training, seed, corpus, sort = TRUE))
# Remove every observation of the (corpus, code) tokens listed in `tokens`.
#
# df:     observation-level data with corpus and code columns.
# tokens: data frame whose (corpus, code) pairs identify the tokens to drop;
#         other columns and duplicate pairs are ignored.
# reason: text completing "... corpus tokens which <reason>" in the log lines.
# Logs the number of observations and distinct tokens dropped, then the number
# of observations and tokens that remain, and returns the filtered df.
drop_tokens <- function(df, tokens, reason) {
  tokens <- distinct(tokens, corpus, code)
  dropped <- semi_join(df, tokens, by = c("corpus", "code"))
  loginfo(paste("Dropping", nrow(dropped), "observations corresponding to corpus tokens which", reason))
  loginfo(paste("Dropping", nrow(tokens), "tokens which", reason))
  kept <- anti_join(df, tokens, by = c("corpus", "code"))
  loginfo(paste("After drop,", nrow(kept), "observations (", kept %>% group_by(corpus, code) %>% n_groups(), " tokens) remain."))
  kept
}

# Tokens observed fewer times than the per-corpus maximum number of
# model--seed--training observations, i.e. some configuration lacks a row.
to_drop <- all_data %>%
  count(corpus, code) %>%
  group_by(corpus) %>%
  filter(n != max(n)) %>%
  ungroup() %>%
  select(code, corpus)
#to_drop = all_data %>% group_by(corpus, code) %>% filter(n() != ideal_token_obs_count) %>% ungroup()
all_data <- drop_tokens(all_data, to_drop, "are missing observations for some model.")
# Recorded output of an earlier run (its "observations" count was actually the
# token count; the observation count is now computed from all_data):
#   2020-05-29 10:42:54 INFO::Dropping 10342 observations corresponding to corpus tokens which are missing observations for some model.
#   2020-05-29 10:42:54 INFO::Dropping 10342 tokens which are missing observations for some model.
#   2020-05-29 10:42:55 INFO::After drop, 962274 observations ( 33117 tokens) remain.

# Tokens to which any model assigns exactly zero surprisal.
to_drop_zero_surps <- all_data %>% group_by(corpus, code) %>% filter(any(surprisal == 0)) %>% ungroup()
all_data <- drop_tokens(all_data, to_drop_zero_surps, "have surprisal zeros for some model.")
# Recorded output of an earlier run:
#   2020-05-29 10:42:56 INFO::Dropping 116 observations corresponding to corpus tokens which have surprisal zeros for some model.
#   2020-05-29 10:42:56 INFO::Dropping 4 tokens which have surprisal zeros for some model.
#   2020-05-29 10:42:56 INFO::After drop, 962158 observations ( 33113 tokens) remain.

# Tokens with a zero psychometric measurement in any row.
to_drop_zero_psychs <- all_data %>% group_by(corpus, code) %>% filter(any(psychometric == 0)) %>% ungroup()
all_data <- drop_tokens(all_data, to_drop_zero_psychs, "have psychometric zeros for some model.")
# Recorded output of an earlier run:
#   2020-05-29 10:42:57 INFO::Dropping 14935 observations corresponding to corpus tokens which have psychometric zeros for some model.
#   2020-05-29 10:42:57 INFO::Dropping 515 tokens which have psychometric zeros for some model.
#   2020-05-29 10:42:57 INFO::After drop, 947223 observations ( 32598 tokens) remain.
# Fit a linear model for one model--training--seed--corpus slice of the
# training data and score it on the matching slice of the test data.
#
# df:        training rows for a single model/training/seed/corpus group.
# test_data: full test fold; it is subset here (semi_join on training, model,
#            seed, corpus) to the rows belonging to the group(s) in `df`.
# formula:   regression formula passed to lm().
# fold:      fold index, appended to the name under which the fit is stored.
# store_env: environment that receives the fitted lm, assigned under the name
#            "<model> <training> <seed> <corpus> <fold>".
# Returns a one-row data frame with log_lik (in-sample log-likelihood),
# test_lik (summed held-out log-likelihood) and test_mse (held-out MSE).
# NOTE(review): log_lik comes from logLik() while test_lik uses sigma(); these
# may rely on different residual-variance estimates -- confirm before
# comparing the two directly.
get_lm_data <- function(df, test_data, formula, fold, store_env) {
  #this_lm <- gam(formula, data=df);
  this_lm <- lm(formula, data = df)
  this_test_data <- semi_join(test_data, df, by = c("training", "model", "seed", "corpus"))
  # Save the lm in store_env so that residuals and per-example likelihoods can
  # be computed from it later.
  lm_name <- paste(unique(paste(df$model, df$training, df$seed, df$corpus))[1], fold)
  assign(lm_name, this_lm, envir = store_env)
  summarise(df,
            log_lik = as.numeric(logLik(this_lm, REML = FALSE)),
            test_lik = logLik_test(this_lm, this_test_data, this_test_data$psychometric),
            test_mse = mse_test(this_lm, this_test_data, this_test_data$psychometric))
}
# Score test rows with an lm previously stored in store_env by get_lm_data().
#
# df:        test rows for a single model/training/seed/corpus group.
# fold:      fold index used when the lm was stored.
# store_env: environment holding fits named "<model> <training> <seed> <corpus> <fold>".
# Returns df with two added columns: likelihood (per-row Gaussian
# log-likelihood) and resid (observed psychometric minus prediction).
get_lm_residuals <- function(df, fold, store_env) {
  group_label <- unique(paste(df$model, df$training, df$seed, df$corpus))[1]
  fitted_lm <- get(paste(group_label, fold), envir = store_env)
  per_obs_lik <- logLik_test_per(fitted_lm, df, df$psychometric)
  residual_vec <- df$psychometric - predict(fitted_lm, df, re.form = NA)
  mutate(df, likelihood = .env$per_obs_lik, resid = .env$residual_vec)
}
# Per-example gain in held-out log-likelihood of the full model over the
# baseline, for one model/training/seed/corpus group of a test fold.
#
# test_data:    test rows for a single group.
# fold:         fold index used when the fits were stored.
# baseline_env: environment holding baseline fits, keyed "<model> <training> <seed> <corpus> <fold>".
# full_env:     environment holding full-model fits under the same keys.
# Returns test_data with an extra delta_log_lik column (full minus baseline).
get_lm_delta_log_lik <- function(test_data, fold, baseline_env, full_env) {
  group_label <- unique(paste(test_data$model, test_data$training,
                              test_data$seed, test_data$corpus))[1]
  lm_key <- paste(group_label, fold)
  baseline_fit <- get(lm_key, envir = baseline_env)
  full_fit <- get(lm_key, envir = full_env)
  y <- test_data$psychometric
  improvement <- logLik_test_per(full_fit, test_data, y) -
    logLik_test_per(baseline_fit, test_data, y)
  cbind(test_data, delta_log_lik = improvement)
}
#####
# Define regression formulae.
# Eye-tracking (dundee) regressions use features of the current token and up
# to 2 preceding tokens; self-paced reading regressions use up to 4 preceding
# tokens. Baselines contain only frequency and length; the full models add
# surprisal at the same lags.
# Earlier GAM (mgcv) specifications, kept for reference:
#baseline_rt_regression = psychometric ~ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr")
#baseline_sprt_regression = psychometric ~ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr") + te(prev3_freq, prev3_len, bs = "cr") + te(prev4_freq, prev4_len, bs = "cr")
#full_rt_regression = psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + s(prev2_surp, bs = "cr", k = 20) + te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr")
#full_sprt_regression = psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + s(prev2_surp, bs = "cr", k = 20) + s(prev3_surp, bs = "cr", k = 20) + s(prev4_surp, bs = "cr", k = 20) + te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr") + te(prev3_freq, prev3_len, bs = "cr") + te(prev4_freq, prev4_len, bs = "cr")
baseline_rt_regression <- psychometric ~ freq + prev_freq + prev2_freq +
  len + prev_len + prev2_len
baseline_sprt_regression <- psychometric ~ freq + prev_freq + prev2_freq +
  prev3_freq + prev4_freq + len + prev_len + prev2_len + prev3_len + prev4_len
full_sprt_regression <- psychometric ~ surprisal + prev_surp + prev2_surp +
  prev3_surp + prev4_surp + freq + prev_freq + prev2_freq + prev3_freq +
  prev4_freq + len + prev_len + prev2_len + prev3_len + prev4_len
full_rt_regression <- psychometric ~ surprisal + prev_surp + prev2_surp +
  freq + prev_freq + prev2_freq + len + prev_len + prev2_len
#####
# Prepare frames/environments for storing results/objects.
baseline_results <- data.frame()
full_model_results <- data.frame()
baseline_residuals <- data.frame()
full_residuals <- data.frame()
log_lik_deltas <- data.frame()
# Randomly shuffle the rows. The RNG seed is fixed so that the fold assignment,
# and therefore every cross-validated result, is reproducible across runs;
# change the value to draw a different split.
set.seed(42)
all_data <- all_data[sample(nrow(all_data)), ]
# Assign the shuffled rows to K contiguous folds of (nearly) equal size.
K <- 10
folds <- cut(seq_len(nrow(all_data)), breaks = K, labels = FALSE)
# Per-corpus model fitting used by the K-fold cross-validation loop.
#
# baseline_corpus() / full_model_corpus() choose the regression formula for a
# group: the dundee corpus uses the *_rt_regression formula, every other corpus
# the *_sprt_regression formula. Fitting, storing the lm in `env`, and scoring
# on test_data are delegated to get_lm_data().
# `corpus` may be a vector (e.g. a group's corpus column); only its first
# element is consulted, so both functions accept length > 1 input.
baseline_corpus <- function(corpus, df, test_data, fold, env) {
  formula <- if (corpus[1] == "dundee") baseline_rt_regression else baseline_sprt_regression
  get_lm_data(df, test_data, formula, fold, env)
}
full_model_corpus <- function(corpus, df, test_data, fold, env) {
  formula <- if (corpus[1] == "dundee") full_rt_regression else full_sprt_regression
  get_lm_data(df, test_data, formula, fold, env)
}
# Environments that hold the fitted LMs (named "<model> <training> <seed>
# <corpus> <fold>" by get_lm_data), queried later for residuals and other
# metrics.
baseline_env <- new.env()
full_env <- new.env()
for (i in seq_len(K)) {
  # Hold out fold i and train on the remaining folds. A logical mask (rather
  # than negative indices) keeps train_data correct even if a fold were empty.
  test_rows <- folds == i
  test_data <- all_data[test_rows, ]
  train_data <- all_data[!test_rows, ]
  # Compute a baseline linear model for each model--training--seed--corpus combination.
  baselines <- train_data %>%
    group_by(model, training, seed, corpus) %>%
    do(baseline_corpus(unique(.$corpus), ., test_data, i, baseline_env)) %>%
    ungroup() %>%
    mutate(seed = as.factor(seed),
           fold = i)
  baseline_results <- rbind(baseline_results, baselines)
  # Compute a full linear model for each model--training--seed--corpus combination.
  full_models <- train_data %>%
    group_by(model, training, seed, corpus) %>%
    do(full_model_corpus(unique(.$corpus), ., test_data, i, full_env)) %>%
    ungroup() %>%
    mutate(seed = as.factor(seed),
           fold = i)
  full_model_results <- rbind(full_model_results, full_models)
  # Per-example delta log-likelihood (full minus baseline) on the held-out fold.
  fold_log_lik_deltas <- test_data %>%
    group_by(model, training, seed, corpus) %>%
    do(get_lm_delta_log_lik(., i, baseline_env, full_env)) %>%
    ungroup()
  log_lik_deltas <- rbind(log_lik_deltas, fold_log_lik_deltas)
  # Held-out residuals and per-example likelihoods for each model family.
  fold_baseline_residuals <- test_data %>%
    group_by(model, training, seed, corpus) %>%
    do(get_lm_residuals(., i, baseline_env)) %>%
    ungroup()
  baseline_residuals <- rbind(baseline_residuals, fold_baseline_residuals)
  fold_full_residuals <- test_data %>%
    group_by(model, training, seed, corpus) %>%
    do(get_lm_residuals(., i, full_env)) %>%
    ungroup()
  full_residuals <- rbind(full_residuals, fold_full_residuals)
}
|==================================================================================================== | 56% ~2 s remaining
|====================================================================================================== | 57% ~2 s remaining
|============================================================================================================= | 60% ~1 s remaining
|=================================================================================================================== | 64% ~1 s remaining
|========================================================================================================================= | 67% ~1 s remaining
|=============================================================================================================================== | 70% ~1 s remaining
|===================================================================================================================================== | 74% ~1 s remaining
|=========================================================================================================================================== | 77% ~1 s remaining
|============================================================================================================================================= | 78% ~1 s remaining
|==================================================================================================================================================== | 82% ~1 s remaining
|========================================================================================================================================================== | 85% ~1 s remaining
|================================================================================================================================================================ | 89% ~0 s remaining
|==================================================================================================================================================================== | 91% ~0 s remaining
|====================================================================================================================================================================== | 92% ~0 s remaining
|============================================================================================================================================================================ | 95% ~0 s remaining
|================================================================================================================================================================================== | 99% ~0 s remaining
|====================================================================================================== | 57% ~2 s remaining
|========================================================================================================== | 59% ~1 s remaining
|============================================================================================================= | 60% ~1 s remaining
|================================================================================================================= | 62% ~1 s remaining
|=================================================================================================================== | 64% ~1 s remaining
|======================================================================================================================= | 66% ~1 s remaining
|========================================================================================================================= | 67% ~1 s remaining
|============================================================================================================================= | 69% ~1 s remaining
|=============================================================================================================================== | 70% ~1 s remaining
|=================================================================================================================================== | 73% ~1 s remaining
|===================================================================================================================================== | 74% ~1 s remaining
|========================================================================================================================================= | 76% ~1 s remaining
|============================================================================================================================================= | 78% ~1 s remaining
|================================================================================================================================================== | 81% ~1 s remaining
|==================================================================================================================================================== | 82% ~1 s remaining
|======================================================================================================================================================== | 84% ~1 s remaining
|========================================================================================================================================================== | 85% ~1 s remaining
|============================================================================================================================================================== | 88% ~0 s remaining
|================================================================================================================================================================ | 89% ~0 s remaining
|==================================================================================================================================================================== | 91% ~0 s remaining
|====================================================================================================================================================================== | 92% ~0 s remaining
|========================================================================================================================================================================== | 94% ~0 s remaining
|============================================================================================================================================================================ | 95% ~0 s remaining
|================================================================================================================================================================================ | 98% ~0 s remaining
|================================================================================================================================================================================== | 99% ~0 s remaining
|=========================================================================================================================== | 65% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================== | 85% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|========================================================================================================================================================================================== | 98% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|========================================================================================================== | 56% ~2 s remaining
|============================================================================================================ | 57% ~2 s remaining
|=================================================================================================================== | 60% ~1 s remaining
|======================================================================================================================= | 62% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|=================================================================================================== | 52% ~2 s remaining
|====================================================================================================== | 53% ~2 s remaining
|========================================================================================================== | 56% ~2 s remaining
|============================================================================================================ | 57% ~2 s remaining
|=================================================================================================================== | 60% ~2 s remaining
|======================================================================================================================= | 62% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|============================================================================================================ | 57% ~2 s remaining
|================================================================================================================ | 59% ~1 s remaining
|=================================================================================================================== | 60% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|=================================================================================================================================================== | 77% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================== | 85% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|================================================================================================= | 51% ~2 s remaining
|====================================================================================================== | 53% ~2 s remaining
|========================================================================================================== | 56% ~2 s remaining
|============================================================================================================ | 57% ~2 s remaining
|================================================================================================================ | 59% ~2 s remaining
|=================================================================================================================== | 60% ~2 s remaining
|======================================================================================================================= | 62% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|========================================================================================================================================================================================== | 98% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~0 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|======================================================================================================================================================================================== | 97% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|========================================================================================================================================================================================== | 98% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|======================================================================================================================= | 62% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|=================================================================================================================================================== | 77% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|========================================================================================================================================================================================== | 98% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|============================================================================================================== | 58% ~1 s remaining
|=================================================================================================================== | 60% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|=================================================================================================================================================== | 77% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|============================================================================================================ | 57% ~2 s remaining
|================================================================================================================ | 59% ~1 s remaining
|=================================================================================================================== | 60% ~1 s remaining
|======================================================================================================================= | 62% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|========================================================================================================== | 56% ~2 s remaining
|============================================================================================================ | 57% ~2 s remaining
|=================================================================================================================== | 60% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|=================================================================================================================================================== | 77% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|====================================================================================================== | 53% ~2 s remaining
|========================================================================================================== | 56% ~2 s remaining
|============================================================================================================ | 57% ~2 s remaining
|================================================================================================================ | 59% ~2 s remaining
|=================================================================================================================== | 60% ~1 s remaining
|======================================================================================================================= | 62% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|========================================================================================================================================================================================== | 98% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|=================================================================================================================== | 60% ~1 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|=================================================================================================================================================== | 77% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================== | 85% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~0 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|========================================================================================================================================================================================== | 98% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|================================================================================================================================== | 68% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|=================================================================================================================================================== | 77% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================== | 85% ~1 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
|========================================================================================================================= | 64% ~1 s remaining
|============================================================================================================================= | 66% ~1 s remaining
|================================================================================================================================ | 67% ~1 s remaining
|==================================================================================================================================== | 69% ~1 s remaining
|====================================================================================================================================== | 70% ~1 s remaining
|========================================================================================================================================== | 73% ~1 s remaining
|============================================================================================================================================= | 74% ~1 s remaining
|================================================================================================================================================= | 76% ~1 s remaining
|===================================================================================================================================================== | 78% ~1 s remaining
|========================================================================================================================================================== | 81% ~1 s remaining
|============================================================================================================================================================ | 82% ~1 s remaining
|================================================================================================================================================================ | 84% ~1 s remaining
|================================================================================================================================================================== | 85% ~0 s remaining
|======================================================================================================================================================================= | 88% ~0 s remaining
|========================================================================================================================================================================= | 89% ~0 s remaining
|============================================================================================================================================================================= | 91% ~0 s remaining
|=============================================================================================================================================================================== | 92% ~0 s remaining
|==================================================================================================================================================================================== | 94% ~0 s remaining
|====================================================================================================================================================================================== | 95% ~0 s remaining
|========================================================================================================================================================================================== | 98% ~0 s remaining
|============================================================================================================================================================================================ | 99% ~0 s remaining
# Optional checkpoints of per-observation residuals (disabled).
# write.csv(full_residuals, "../data/analysis_checkpoints/full_residuals.csv")
# write.csv(baseline_residuals, "../data/analysis_checkpoints/baseline_residuals.csv")

# Collapse the per-observation log-likelihood deltas into one mean and one
# standard error of the mean per model--training--seed--corpus cell.
model_deltas <- log_lik_deltas %>%
  group_by(model, training, seed, corpus) %>%
  summarise(
    mean_delta_log_lik = mean(delta_log_lik),
    sem_delta_log_lik = sd(delta_log_lik) / sqrt(length(delta_log_lik))
  ) %>%
  # summarise() only peels off the last grouping level; drop the remaining
  # groups so later steps operate on the whole table.
  ungroup()

# Checkpoint the regression results.
write.csv(full_model_results, "../data/analysis_checkpoints/full_model_result.csv")
write.csv(baseline_results, "../data/analysis_checkpoints/baseline_results.csv")
# To reload the checkpoints written above instead of recomputing (note that
# write.csv() stores row names, which read.csv() returns as an extra `X` column):
# full_model_results <- read.csv("../data/analysis_checkpoints/full_model_result.csv")
# baseline_results <- read.csv("../data/analysis_checkpoints/baseline_results.csv")

# Axis label for the predictive-power metric used in the plots below.
metric <- "ΔLogLik"
# metric <- "-ΔMSE"

# Select the relevant metric: expose it under the generic names
# delta_test_mean / delta_test_sem so downstream code is metric-agnostic.
model_deltas <- model_deltas %>%
  mutate(
    delta_test_mean = mean_delta_log_lik,
    delta_test_sem = sem_delta_log_lik
    # delta_test_mean = mean_delta_mse,
    # delta_test_sem = sem_delta_mse
  ) %>%
  # Remove the raw metrics.
  select(-mean_delta_log_lik, -sem_delta_log_lik)
  # select(-mean_delta_mse, -sem_delta_mse)
model_deltas
# Sanity check: training on train+test data should yield improved performance over training on just training data. (When evaluating on test data.)
# full_baselines = all_data %>%
# group_by(model, training, seed, corpus) %>%
# summarise(baseline_train_all_test_lik = logLik_test(lm(psychometric ~ len + freq + sent_pos, data=.), semi_join(test_data, ., by=c("training", "model", "seed", "corpus")), semi_join(test_data, ., by=c("training", "model", "seed", "corpus"))$psychometric)) %>%
# ungroup()
# full_baselines
#
# full_baselines %>%
# right_join(baselines, by=c("seed", "training", "model", "corpus")) %>%
# mutate(delta=baseline_train_all_test_lik-baseline_test_lik) %>%
# select(-baseline_lik) # %>%
# #select(-baseline_test_lik, -baseline_train_all_test_lik, -baseline_lik, -baseline_test_mse)
# Language-model metadata with derived covariates used to group, color and
# regress on the models below.
language_model_data <- read.csv("../data/model_metadata.csv") %>%
  mutate(
    # Rename "gpt-2" to "gpt2" (presumably to match the model names in
    # model_deltas -- confirm).
    model = as.character(model),
    model = if_else(model == "gpt-2", "gpt2", model),
    model = as.factor(model),
    # Training-corpus size in millions of tokens, keyed on the BLLIP subset;
    # any other training value is left NA.
    train_size = case_when(
      str_starts(training, "bllip-lg") ~ 42,
      str_starts(training, "bllip-md") ~ 15,
      str_starts(training, "bllip-sm") ~ 5,
      str_starts(training, "bllip-xs") ~ 1
    ),
    # Training vocabulary usually covaries with the training corpus,
    # but BPE models share one vocabulary across training corpora.
    training_vocab = as.factor(ifelse(str_detect(training, "gptbpe"),
                                      "gptbpe",
                                      as.character(training))),
    # Training corpus with the "-gptbpe" suffix stripped.
    training_source = as.factor(str_replace(as.character(training),
                                            "-gptbpe", "")),
    seed = as.factor(seed)
  ) %>%
  select(-pid, -test_loss) %>%
  # Keep a single row per model--training--seed.
  distinct(model, training, seed, .keep_all = TRUE)
table(language_model_data$seed)
0 111 120 922 1111 3602 4301 7245 7877 28066 28068 44862 51272 64924 1581807512 1581807578 1581861474 1581955288
4 7 6 5 4 1 1 1 1 1 1 1 1 1 1 1 1 1
1582126320 1586986276 1587139950
1 1 1
table(model_deltas[["seed"]])
111 120 607 922 1111 3602 4301 7245 7877 28066 28068 44862 51272 64924 1581807512 1581807578 1581861474 1581955288
9 9 1 9 12 3 3 3 3 3 3 3 3 3 3 3 3 3
1582126320 1586986276 1587139950
3 3 3
First, join the delta-metric data with the auxiliary language-model metadata.
# Outer-merge the delta metrics with the model metadata, then drop every row
# that has a missing value in any column (including unmatched rows).
model_deltas <- merge(model_deltas, language_model_data,
                      by = c("seed", "training", "model"), all = TRUE) %>%
  drop_na()
model_deltas
Also join on the original linear-model data rather than collapsing to delta metrics; this will support later regressions that do not collapse across folds.
# Drop ordered-neurons models from every analysis below.
model_deltas <- filter(model_deltas, model != "ordered-neurons")
# Exploratory views of the surprisal data: observation counts per corpus and
# the per-corpus distributions of frequency, length and surprisal.
ggplot(all_data, aes(x = corpus)) + geom_bar()
print(summarise(group_by(all_data, corpus), n = n()))
ggplot(all_data, aes(x = freq, color = corpus)) + geom_density()
ggplot(all_data, aes(x = len, color = corpus)) + geom_density()
ggplot(all_data, aes(x = surprisal, color = corpus)) + geom_density()
# SG score vs. predictive power, one panel per corpus. Error bars span
# +/- one standard error of the mean delta.
ggplot(model_deltas, aes(x = sg_score, y = delta_test_mean)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem)) +
  geom_smooth(method = "lm", se = TRUE) +
  geom_point(aes(color = training_vocab, shape = model),
             stat = "identity", position = "dodge", alpha = 1, size = 3) +
  ylab(metric) +
  xlab("Syntax Generalization Score") +
  ggtitle("Syntactic Generalization vs. Predictive Power") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#888888")) +
  facet_grid(~corpus, scales = "free") +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")
# ggsave("./cogsci_images/sg_loglik.png", height = 5, width = 6)
We control for the effects of perplexity by residualizing both predictive power and SG score on test perplexity (separately within each corpus), then relating the two sets of residuals.
# Within each corpus, residualize both the predictive-power delta and the SG
# score on test perplexity. The model has a shared intercept and a separate
# perplexity slope per training vocabulary. Relating the two residual series
# isolates the part of the SG--predictive-power relationship that perplexity
# does not explain.
d_resid <- model_deltas %>%
  drop_na() %>%
  group_by(corpus) %>%
  mutate(
    resid.delta = resid(lm(delta_test_mean ~ training_vocab:test_ppl)),
    resid.sg = resid(lm(sg_score ~ training_vocab:test_ppl))
  ) %>%
  ungroup()
# # Compute summary statistics across model--training--seed--corpus.
# group_by(model, training_vocab, corpus, seed) %>%
# summarise(resid.delta.mean = mean(resid.delta),
# resid.delta.sem = sd(resid.delta) / sqrt(length(resid.delta)),
# resid.sg.mean = mean(resid.sg),
# resid.sg.sem = sd(resid.sg) / sqrt(length(resid.sg))) %>%
# ungroup()
# Residual SG score vs. residual predictive power (both net of perplexity),
# one panel per corpus; written to ../images/cuny2020/dll_sg.png.
# (To drop a corpus, filter d_resid first, e.g. corpus != "bnc-brown".)
ggplot(d_resid, aes(x = resid.sg, y = resid.delta)) +
  theme_bw() +
  scale_shape_manual(values = c(16, 17, 15, 18)) +
  geom_smooth(method = "lm", se = TRUE, alpha = 0.3) +
  geom_point(aes(shape = model, color = training_vocab),
             stat = "identity", position = "dodge", alpha = 1, size = 4) +
  ylab(paste("Residual", metric)) +
  xlab("Residual Syntax Generalization Score") +
  ggtitle("Syntactic Generalization vs. Predictive Power") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#f0941f")) +
  facet_grid(. ~ corpus, scales = "free") +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "right")
ggsave("../images/cuny2020/dll_sg.png", height = 4.5, width = 9)
# For one corpus, test whether SG score adds predictive power beyond test
# perplexity. Fits delta ~ per-vocabulary perplexity slopes, adds sg_score,
# prints the nested-model F test, and returns the summary of the larger model.
#
# Args:
#   cur_corpus: value of the `corpus` column to restrict model_deltas to.
# Returns:
#   summary() of the model that includes sg_score.
do_stepwise_regression <- function(cur_corpus) {
  # The name `regression_data` shows up in the printed model call.
  regression_data <- model_deltas %>% filter(corpus == cur_corpus)
  print("----------------------")
  print(cur_corpus)
  ppl_only <- lm(delta_test_mean ~ training_vocab:test_ppl,
                 data = regression_data)
  ppl_plus_sg <- lm(delta_test_mean ~ training_vocab:test_ppl + sg_score,
                    data = regression_data)
  print(anova(ppl_only, ppl_plus_sg))
  summary(ppl_plus_sg)
}
do_stepwise_regression(cur_corpus = "bnc-brown")
[1] "----------------------"
[1] "bnc-brown"
Analysis of Variance Table
Model 1: delta_test_mean ~ training_vocab:test_ppl
Model 2: delta_test_mean ~ training_vocab:test_ppl + sg_score
Res.Df RSS Df Sum of Sq F Pr(>F)
1 23 0.00024405
2 22 0.00023954 1 4.508e-06 0.414 0.5266
Call:
lm(formula = delta_test_mean ~ training_vocab:test_ppl + sg_score,
data = regression_data)
Residuals:
Min 1Q Median 3Q Max
-0.0049611 -0.0018306 -0.0005214 0.0006757 0.0064101
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 1.316e-02 3.624e-03 3.632 0.00147 **
sg_score 3.321e-03 5.161e-03 0.643 0.52658
training_vocabbllip-lg:test_ppl 1.925e-05 3.358e-05 0.573 0.57227
training_vocabbllip-md:test_ppl -4.247e-05 2.799e-05 -1.518 0.14336
training_vocabbllip-sm:test_ppl -5.112e-05 2.492e-05 -2.051 0.05233 .
training_vocabbllip-xs:test_ppl -6.235e-05 1.663e-05 -3.750 0.00111 **
training_vocabgptbpe:test_ppl -1.119e-05 8.169e-06 -1.370 0.18450
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 0.0033 on 22 degrees of freedom
Multiple R-squared: 0.642, Adjusted R-squared: 0.5444
F-statistic: 6.576 on 6 and 22 DF, p-value: 0.0004362
do_stepwise_regression(cur_corpus = "dundee")
[1] "----------------------"
[1] "dundee"
Analysis of Variance Table
Model 1: delta_test_mean ~ training_vocab:test_ppl
Model 2: delta_test_mean ~ training_vocab:test_ppl + sg_score
Res.Df RSS Df Sum of Sq F Pr(>F)
1 23 9.6740e-05
2 22 9.6289e-05 1 4.5086e-07 0.103 0.7513
Call:
lm(formula = delta_test_mean ~ training_vocab:test_ppl + sg_score,
data = regression_data)
Residuals:
Min 1Q Median 3Q Max
-0.0025554 -0.0011305 -0.0007635 0.0011585 0.0038662
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 7.120e-03 2.298e-03 3.098 0.00525 **
sg_score 1.050e-03 3.272e-03 0.321 0.75127
training_vocabbllip-lg:test_ppl -1.188e-05 2.129e-05 -0.558 0.58250
training_vocabbllip-md:test_ppl -1.813e-05 1.774e-05 -1.022 0.31796
training_vocabbllip-sm:test_ppl -2.308e-05 1.580e-05 -1.461 0.15826
training_vocabbllip-xs:test_ppl -2.271e-05 1.054e-05 -2.154 0.04243 *
training_vocabgptbpe:test_ppl -1.894e-06 5.179e-06 -0.366 0.71816
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 0.002092 on 22 degrees of freedom
Multiple R-squared: 0.3203, Adjusted R-squared: 0.135
F-statistic: 1.728 on 6 and 22 DF, p-value: 0.1615
do_stepwise_regression(cur_corpus = "natural-stories")
[1] "----------------------"
[1] "natural-stories"
Analysis of Variance Table
Model 1: delta_test_mean ~ training_vocab:test_ppl
Model 2: delta_test_mean ~ training_vocab:test_ppl + sg_score
Res.Df RSS Df Sum of Sq F Pr(>F)
1 23 3.9684e-05
2 22 3.7375e-05 1 2.3089e-06 1.3591 0.2562
Call:
lm(formula = delta_test_mean ~ training_vocab:test_ppl + sg_score,
data = regression_data)
Residuals:
Min 1Q Median 3Q Max
-0.0022223 -0.0008183 0.0002092 0.0009282 0.0023418
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 1.076e-02 1.432e-03 7.512 1.65e-07 ***
sg_score -2.377e-03 2.039e-03 -1.166 0.2562
training_vocabbllip-lg:test_ppl -6.333e-05 1.326e-05 -4.774 9.11e-05 ***
training_vocabbllip-md:test_ppl -6.220e-05 1.105e-05 -5.627 1.17e-05 ***
training_vocabbllip-sm:test_ppl -6.636e-05 9.844e-06 -6.742 8.93e-07 ***
training_vocabbllip-xs:test_ppl -3.823e-05 6.568e-06 -5.821 7.41e-06 ***
training_vocabgptbpe:test_ppl -8.201e-06 3.227e-06 -2.541 0.0186 *
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 0.001303 on 22 degrees of freedom
Multiple R-squared: 0.7632, Adjusted R-squared: 0.6986
F-statistic: 11.82 on 6 and 22 DF, p-value: 6.283e-06
# The residualized analysis and the stepwise regression analysis
# should yield the same coefficients for the SG score variable
# (Frisch--Waugh--Lovell): regressing the residualized delta on the
# residualized SG score recovers the `sg_score` coefficient from the
# perplexity + SG regression.
#
# Below, we compute the slope coefficient for the SG term in the
# residualized analyses, per corpus. These should match the `sg_score`
# estimates printed by do_stepwise_regression() above.
#
# broom is called through its namespace because it is not among the packages
# attached at the top of this script (tidyverse does not attach it).
d_resid %>%
  group_by(corpus) %>%
  group_modify(~ broom::tidy(
    lm(resid.delta ~ training_vocab:test_ppl + resid.sg, data = .x)
  ) %>%
    filter(term == "resid.sg")) %>%
  select(corpus, estimate)
# Test perplexity vs. predictive power, one panel per corpus, written to
# ../images/cuny2020/ppl_loglik.png.
# Broken x axis: perplexities above 500 are drawn at 329.9, a "//" at x = 275
# marks the break, and the ticks at 300 and 350 are relabelled 500 and 550.
model_deltas %>%
  mutate(test_ppl = if_else(test_ppl > 500, 329.9, test_ppl)) %>%
  ggplot(aes(x = test_ppl, y = delta_test_mean, color = training_vocab,
             fill = training_vocab, ymin = 0)) +
  theme_bw() +
  geom_text(aes(x = 275, y = 0, label = "//")) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem),
                alpha = 0.4) +
  # geom_smooth(method = "lm", se = FALSE) +
  geom_point(aes(shape = model, color = training_vocab),
             stat = "identity", position = "dodge", alpha = 1, size = 4) +
  ylab(metric) +
  xlab("Test Perplexity") +
  # coord_cartesian(ylim = c(1, 16)) +
  ggtitle("Test Perplexity vs. Predictive Power") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#f0941f")) +
  scale_shape_manual(values = c(16, 17, 15, 18)) +
  scale_x_continuous(labels = c(0, 50, 100, 150, 200, 250, 500, 550),
                     breaks = c(0, 50, 100, 150, 200, 250, 300, 350),
                     minor_breaks = NULL) +
  scale_y_continuous(limits = c(0, NA), expand = c(0, 0)) +
  facet_wrap(~corpus, scales = "free") +
  coord_cartesian(clip = "off") +
  theme(axis.text = element_text(size = 12),
        strip.text.x = element_text(size = 12),
        legend.text = element_text(size = 12),
        axis.title = element_text(size = 12),
        legend.position = "right")
ggsave("../images/cuny2020/ppl_loglik.png", height = 4.2, width = 12)
# Pearson correlation (stats::cor.test) between the predictive-power delta and
# test perplexity.
#
# Args:
#   df: data frame with numeric `delta_test_mean` and `test_ppl` columns.
# Returns:
#   summarise() output with the correlation estimate (`cor`) and its p-value
#   (`p`). Both are NA when fewer than three complete pairs exist, because
#   cor.test() errors on so few observations.
dll_cor_test <- function(df) {
  x <- df$delta_test_mean
  y <- df$test_ppl
  if (sum(complete.cases(x, y)) < 3) {
    return(df %>% summarise(cor = NA_real_, p = NA_real_))
  }
  # Run the test once and reuse it for both statistics.
  corr_test <- cor.test(x, y)
  df %>%
    summarise(cor = !!corr_test$estimate, p = !!corr_test$p.value)
}
# Per training corpus and reading-time corpus, correlate predictive power with
# test perplexity across the non-5gram models. Only cells with at least three
# models are kept, since a correlation test needs three points.
model_deltas %>%
  filter(model != "5gram") %>%
  group_by(training, corpus) %>%
  filter(n() > 2) %>%
  group_modify(~ dll_cor_test(.x)) %>%
  ungroup() %>%
  arrange(corpus)
NA
# Training-set size (log of millions of tokens) vs. predictive power, colored
# by model, with BPE-vocabulary models drawn with a distinct point shape.
model_deltas %>%
  mutate(
    train_size = log(train_size),
    bpe = as.factor(if_else(training_vocab == "gptbpe", "yes", "no"))
  ) %>%
  ggplot(aes(x = train_size, y = delta_test_mean, color = model)) +
  theme_bw() +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem),
                width = 0.1) +
  geom_smooth(method = "lm", se = TRUE, alpha = 0.2) +
  geom_point(aes(shape = bpe),
             stat = "identity", position = "dodge", alpha = 1, size = 3) +
  ylab(metric) +
  xlab("Log Million Training Tokens") +
  ggtitle("Training Size vs. Predictive Power") +
  facet_grid(. ~ corpus, scales = "free") +
  # scale_color_manual(values = c("#A42EF1", "#3894C8")) +
  theme(axis.text = element_text(size = 12),
        strip.text.x = element_text(size = 12),
        legend.text = element_text(size = 8),
        legend.title = element_text(size = 8),
        axis.title = element_text(size = 14),
        legend.position = "bottom",
        legend.direction = "horizontal",
        legend.key.width = unit(0.3, "cm"),
        legend.spacing.x = unit(0.1, "cm"))
# ggsave("../images/cuny2020/training_loglik.png", height = 5, width = 5)
# Pearson correlation (stats::cor.test) between training-set size and the
# predictive-power delta.
#
# Args:
#   df: data frame with numeric `train_size` and `delta_test_mean` columns.
# Returns:
#   summarise() output with the correlation estimate (`cor`) and its p-value
#   (`p`). Both are NA when fewer than three complete pairs exist, because
#   cor.test() errors on so few observations.
model_cor_test <- function(df) {
  x <- df$train_size
  y <- df$delta_test_mean
  if (sum(complete.cases(x, y)) < 3) {
    return(df %>% summarise(cor = NA_real_, p = NA_real_))
  }
  # Run the test once and reuse it for both statistics.
  corr_test <- cor.test(x, y)
  df %>%
    summarise(cor = !!corr_test$estimate, p = !!corr_test$p.value)
}
# Per model architecture and corpus, correlate training-set size with
# predictive power. Sorted by corpus (the original `arrange()` call had no
# sort keys and so did nothing).
model_deltas %>%
  group_by(model, corpus) %>%
  group_modify(~ model_cor_test(.x)) %>%
  ungroup() %>%
  arrange(corpus)
NA
NA
# Test perplexity vs. SG score (legend suppressed), written to
# ../images/cuny2020/ppl_sg.png.
# Broken x axis: perplexities above 500 are drawn at 329.9, a "//" at x = 275
# marks the break, and the ticks at 300 and 350 are relabelled 500 and 550.
model_deltas %>%
  mutate(test_ppl = if_else(test_ppl > 500, 329.9, test_ppl)) %>%
  ggplot(aes(x = test_ppl, y = sg_score, color = training_vocab)) +
  theme_bw() +
  # geom_smooth(method = "lm", se = TRUE, alpha = 0.2) +
  geom_point(aes(shape = model),
             stat = "identity", position = "dodge", alpha = 0.6, size = 5) +
  geom_text(aes(x = 275, y = 0, label = "//")) +
  ylab("SG Score") +
  xlab("Test Perplexity") +
  ggtitle("Test PPL vs. SG Score") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#f0941f")) +
  scale_shape_manual(values = c(16, 17, 15, 18)) +
  scale_x_continuous(labels = c(0, 50, 100, 150, 200, 250, 500, 550),
                     breaks = c(0, 50, 100, 150, 200, 250, 300, 350),
                     minor_breaks = NULL) +
  scale_y_continuous(limits = c(0, 1), expand = c(0, 0)) +
  theme(axis.text = element_text(size = 12),
        strip.text.x = element_text(size = 12),
        legend.text = element_text(size = 8),
        legend.title = element_text(size = 8),
        axis.title = element_text(size = 14),
        legend.position = "none",
        legend.direction = "horizontal",
        legend.key.width = unit(0.3, "cm"),
        legend.spacing.x = unit(0.1, "cm"))
ggsave("../images/cuny2020/ppl_sg.png", height = 4.5, width = 3)
# Training-set size (log of millions of tokens) vs. SG score, colored by model,
# with BPE-vocabulary models drawn with a distinct point shape.
model_deltas %>%
  mutate(
    train_size = log(train_size),
    bpe = as.factor(if_else(training_vocab == "gptbpe", "yes", "no"))
  ) %>%
  ggplot(aes(x = train_size, y = sg_score, color = model)) +
  theme_bw() +
  geom_smooth(method = "lm", se = TRUE, alpha = 0.2) +
  geom_point(aes(shape = bpe),
             stat = "identity", position = "dodge", alpha = 1, size = 3) +
  ylab("SG Score") +
  xlab("Log Million Training Tokens") +
  ggtitle("Training Size vs. SG Score") +
  # scale_color_manual(values = c("#A42EF1", "#3894C8")) +
  # facet_grid(~model, scales = "free") +
  theme(axis.text = element_text(size = 12),
        strip.text.x = element_text(size = 12),
        legend.text = element_text(size = 8),
        legend.title = element_text(size = 8),
        axis.title = element_text(size = 14),
        legend.position = "bottom",
        legend.direction = "horizontal",
        legend.key.width = unit(0.3, "cm"),
        legend.spacing.x = unit(0.1, "cm"))
# ggsave("../images/cuny2020/training_sg.png", height = 5, width = 4)
# Pearson correlation between training size and SG score for one group of
# rows (called once per model by the grouped pipeline that follows).
#
# df: data frame with numeric columns `train_size` and `sg_score`.
# Returns a one-row data frame with columns `cor` (Pearson r) and `p`
# (two-sided p-value from cor.test()). cor.test() stops with an error when
# fewer than three complete (x, y) pairs exist; in that case both values are
# NA so a single sparse group does not abort the whole grouped computation.
model_cor_test <- function(df) {
  n_complete <- sum(complete.cases(df$train_size, df$sg_score))
  if (n_complete < 3) {
    return(data.frame(cor = NA_real_, p = NA_real_))
  }
  # Run the test once and read both statistics from the same result.
  ct <- cor.test(df$train_size, df$sg_score)
  data.frame(cor = unname(ct$estimate), p = ct$p.value)
}
# Correlation between training size and SG score within each model.
# group_modify() (the replacement for the superseded do()) hands each model's
# rows to model_cor_test() and re-attaches the `model` column to its result.
model_deltas %>%
  group_by(model) %>%
  group_modify(~ model_cor_test(.x)) %>%
  ungroup()
# NOTE(review): the bare `NA` lines below look like leftovers from empty
# notebook chunks; they only evaluate to NA.
NA
NA
NA
# Surprisal marginals: one density curve per model, one panel per corpus,
# with the visible x range limited to 0-21 and wide gaps between panels.
all_data %>%
  ggplot(aes(x = surprisal, color = model)) +
  theme_bw() +
  geom_density() +
  facet_grid(~ corpus) +
  coord_cartesian(xlim = c(0, 21)) +
  theme(panel.spacing = unit(2.5, "cm"))
ggsave("../images/cuny2020/surp_corr_marginals.png", height = 1.5, width = 11)
# Multiplier on the standard error used for the plotted confidence band.
# NOTE(review): 1.97 is slightly above the two-sided 95% normal quantile
# (qnorm(0.975) ~= 1.96); confirm this value is intentional.
k <- 1.97

# Fit a GAM of the psychometric measure on surprisal (current word plus
# spillover positions) with tensor-product frequency x length controls, then
# predict over a surprisal grid with the frequency/length terms removed.
#
# df:             rows for one (training, model, corpus) combination; must
#                 contain psychometric, surprisal, prev*_surp, freq, len,
#                 prev*_freq and prev*_len columns used by the formula.
# corpus, model, training_vocab: labels copied into the output. `corpus` also
#                 picks the formula: "dundee" uses one previous word, every
#                 other corpus uses three.
# ci_mult:        multiplier on the standard error for y_lower / y_upper
#                 (defaults to the global `k`, as before).
# Returns one row per grid point (surprisal 0 to 20 in steps of 0.1) with the
# grid columns plus y, y_lower, y_upper, model, corpus and training_vocab.
fit_gams <- function(df, corpus, model, training_vocab, ci_mult = k) {
  print(paste(corpus, model, training_vocab))
  if (corpus == "dundee") {
    m <- gam(psychometric ~ s(surprisal, bs = 'cr', k = 20) + s(prev_surp, bs = 'cr', k = 20) +
               te(freq, len, bs = 'cr') + te(prev_freq, prev_len, bs = 'cr'),
             data = df)
  } else {
    m <- gam(psychometric ~ s(surprisal, bs = 'cr', k = 20) + s(prev_surp, bs = 'cr', k = 20) +
               s(prev2_surp, bs = 'cr', k = 20) + s(prev3_surp, bs = 'cr', k = 20) +
               te(freq, len, bs = 'cr') + te(prev_freq, prev_len, bs = 'cr') +
               te(prev2_freq, prev2_len, bs = 'cr') + te(prev3_freq, prev3_len, bs = 'cr'),
             data = df)
  }
  # Every surprisal column follows the same grid, so each prediction sums the
  # surprisal smooths of all positions evaluated at that one value.
  # NOTE(review): confirm this summed curve is the intended quantity.
  surp_grid <- seq(0, 20, by = 0.1)
  newdata <- data.frame(surprisal = surp_grid, prev_surp = surp_grid,
                        prev2_surp = surp_grid, prev3_surp = surp_grid,
                        freq = 0, prev_freq = 0, prev2_freq = 0, prev3_freq = 0,
                        len = 0, prev_len = 0, prev2_len = 0, prev3_len = 0)
  # predict.gam() matches `exclude` against the smooth labels reported by
  # summary() (e.g. "te(freq,len)"), not the formula text. Passing the formula
  # text excluded nothing, so take the tensor-product labels from the model.
  smooth_labels <- vapply(m$smooth, function(sm) sm$label, character(1))
  te_labels <- smooth_labels[startsWith(smooth_labels, "te(")]
  pred <- predict(m, newdata = newdata, se.fit = TRUE, type = "link",
                  exclude = te_labels)
  newdata %>%
    mutate(
      y = pred$fit,
      y_lower = pred$fit - ci_mult * pred$se.fit,
      y_upper = pred$fit + ci_mult * pred$se.fit,
      model = model, corpus = corpus, training_vocab = training_vocab
    )
}
# Fit one GAM per (training, model, corpus) cell and stack the predicted
# surprisal curves into `smooths`. Any training set whose name contains
# "gptbpe" gets the vocabulary label "gptbpe"; training_source is the
# training set name with its "-gptbpe" suffix removed.
smooths <- all_data %>%
  mutate(
    training_vocab = as.factor(
      ifelse(str_detect(training, "gptbpe"), "gptbpe", as.character(training))
    ),
    training_source = as.factor(str_remove(as.character(training), "-gptbpe"))
  ) %>%
  group_by(training, model, corpus) %>%
  do(fit_gams(., unique(.$corpus), unique(.$model), unique(.$training_vocab))) %>%
  ungroup()
[1] "bnc-brown 5gram bllip-lg"
|=== | 2% ~4 m remaining [1] "dundee 5gram bllip-lg"
|====== | 4% ~2 m remaining [1] "natural-stories 5gram bllip-lg"
|========== | 6% ~3 m remaining [1] "bnc-brown gpt2 bllip-lg"
|============= | 7% ~3 m remaining [1] "dundee gpt2 bllip-lg"
|================ | 9% ~2 m remaining [1] "natural-stories gpt2 bllip-lg"
|==================== | 11% ~3 m remaining [1] "bnc-brown rnng bllip-lg"
|======================= | 13% ~3 m remaining [1] "dundee rnng bllip-lg"
|========================== | 15% ~3 m remaining [1] "natural-stories rnng bllip-lg"
|============================== | 17% ~3 m remaining [1] "bnc-brown vanilla bllip-lg"
|================================= | 19% ~3 m remaining [1] "dundee vanilla bllip-lg"
|==================================== | 20% ~3 m remaining [1] "natural-stories vanilla bllip-lg"
|======================================== | 22% ~3 m remaining [1] "bnc-brown gpt2 gptbpe"
|=========================================== | 24% ~3 m remaining [1] "dundee gpt2 gptbpe"
|============================================== | 26% ~3 m remaining [1] "natural-stories gpt2 gptbpe"
|================================================== | 28% ~3 m remaining [1] "bnc-brown 5gram bllip-md"
|===================================================== | 30% ~2 m remaining [1] "dundee 5gram bllip-md"
|======================================================== | 31% ~2 m remaining [1] "natural-stories 5gram bllip-md"
|============================================================ | 33% ~2 m remaining [1] "bnc-brown gpt2 bllip-md"
|=============================================================== | 35% ~2 m remaining [1] "dundee gpt2 bllip-md"
|=================================================================== | 37% ~2 m remaining [1] "natural-stories gpt2 bllip-md"
|====================================================================== | 39% ~2 m remaining [1] "bnc-brown rnng bllip-md"
|========================================================================= | 41% ~2 m remaining [1] "dundee rnng bllip-md"
|============================================================================= | 43% ~2 m remaining [1] "natural-stories rnng bllip-md"
|================================================================================ | 44% ~2 m remaining [1] "bnc-brown vanilla bllip-md"
|=================================================================================== | 46% ~2 m remaining [1] "dundee vanilla bllip-md"
|======================================================================================= | 48% ~2 m remaining [1] "natural-stories vanilla bllip-md"
|========================================================================================== | 50% ~2 m remaining [1] "bnc-brown gpt2 gptbpe"
|============================================================================================= | 52% ~2 m remaining [1] "dundee gpt2 gptbpe"
|================================================================================================= | 54% ~2 m remaining [1] "natural-stories gpt2 gptbpe"
|==================================================================================================== | 56% ~2 m remaining [1] "bnc-brown 5gram bllip-sm"
|======================================================================================================= | 57% ~1 m remaining [1] "dundee 5gram bllip-sm"
|=========================================================================================================== | 59% ~1 m remaining [1] "natural-stories 5gram bllip-sm"
|============================================================================================================== | 61% ~1 m remaining [1] "bnc-brown rnng bllip-sm"
|================================================================================================================= | 63% ~1 m remaining [1] "dundee rnng bllip-sm"
|===================================================================================================================== | 65% ~1 m remaining [1] "natural-stories rnng bllip-sm"
|======================================================================================================================== | 67% ~1 m remaining [1] "bnc-brown vanilla bllip-sm"
|============================================================================================================================ | 69% ~1 m remaining [1] "dundee vanilla bllip-sm"
|=============================================================================================================================== | 70% ~59 s remaining [1] "natural-stories vanilla bllip-sm"
|================================================================================================================================== | 72% ~55 s remaining [1] "bnc-brown gpt2 gptbpe"
|====================================================================================================================================== | 74% ~52 s remaining [1] "dundee gpt2 gptbpe"
|========================================================================================================================================= | 76% ~48 s remaining [1] "natural-stories gpt2 gptbpe"
|============================================================================================================================================ | 78% ~45 s remaining [1] "bnc-brown 5gram bllip-xs"
|================================================================================================================================================ | 80% ~43 s remaining [1] "dundee 5gram bllip-xs"
|=================================================================================================================================================== | 81% ~39 s remaining [1] "natural-stories 5gram bllip-xs"
|====================================================================================================================================================== | 83% ~37 s remaining [1] "bnc-brown rnng bllip-xs"
|========================================================================================================================================================== | 85% ~33 s remaining [1] "dundee rnng bllip-xs"
|============================================================================================================================================================= | 87% ~29 s remaining [1] "natural-stories rnng bllip-xs"
|================================================================================================================================================================ | 89% ~25 s remaining [1] "bnc-brown vanilla bllip-xs"
|==================================================================================================================================================================== | 91% ~21 s remaining [1] "dundee vanilla bllip-xs"
|======================================================================================================================================================================= | 93% ~17 s remaining [1] "natural-stories vanilla bllip-xs"
|========================================================================================================================================================================== | 94% ~12 s remaining [1] "bnc-brown gpt2 gptbpe"
|============================================================================================================================================================================== | 96% ~8 s remaining [1] "dundee gpt2 gptbpe"
|================================================================================================================================================================================= | 98% ~4 s remaining [1] "natural-stories gpt2 gptbpe"
|=====================================================================================================================================================================================|100% ~0 s remaining
write.csv(smooths, "../data/gam_smooths.csv")
# GAM surprisal curves with confidence ribbons. Rows are corpora, columns are
# training set x model; colour and fill both encode training_vocab and
# linetype encodes the model. BPE training-set names are rewritten onto two
# lines so they fit in the facet strips.
smooths %>%
  #mutate(training_model = paste(training, "_", model, sep="")) %>%
  #filter(training_vocab == "bllip-xs" | training_vocab == "gptbpe" | training_vocab == "bllip-lg") %>%
  #filter(training != "bllip-md-gptbpe" & training != "bllip-sm-gptbpe") %>%
  mutate(
    training = as.character(training),
    training = case_when(
      training == "bllip-lg-gptbpe" ~ "bpe \n bllip-lg",
      training == "bllip-md-gptbpe" ~ "bpe \n bllip-md",
      training == "bllip-sm-gptbpe" ~ "bpe \n bllip-sm",
      training == "bllip-xs-gptbpe" ~ "bpe \n bllip-xs",
      TRUE ~ training
    )
  ) %>%
  ggplot(aes(x = surprisal, y = y, fill = training_vocab, linetype = model)) +
  theme_bw() +
  geom_line(size = 0.5, aes(color = training_vocab)) +
  #geom_line(aes(y=y_lower), linetype="dashed") +
  geom_ribbon(aes(ymin = y_lower, ymax = y_upper), alpha = 0.3) +
  #geom_line(aes(y=y_upper), linetype="dashed") +
  facet_grid(corpus ~ training + model, scales = "free") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#f0941f")) +
  scale_fill_manual(values = c("bllip-lg" = "#440154FF",
                               "bllip-md" = "#39568CFF",
                               "bllip-sm" = "#1F968BFF",
                               "bllip-xs" = "#73D055FF",
                               "gptbpe" = "#f0941f")) +
  scale_x_continuous(labels = c(0, 10, 20), breaks = c(0, 10, 20), minor_breaks = NULL) +
  ylab("Reading Time") +
  theme(legend.position = "bottom")
ggsave("../images/cuny2020/gam_surp_corr.png", height = 5, width = 12)
NA
NA
NA
# Surprisal vs. processing time for GPT-2 on Dundee (surprisal < 21): one
# smoothed curve per training source (colour) and per vocabulary type
# (linetype; bpe is TRUE when the training name contains "bpe").
all_data %>%
  filter(model == "gpt2", corpus == "dundee") %>%
  filter(surprisal < 21) %>%
  mutate(bpe = str_detect(training, "bpe"),
         training_source = str_replace(training, "-gptbpe", "")) %>%
  ggplot(aes(x = surprisal, y = psychometric, color = training_source, linetype = bpe)) +
  theme_bw() +
  #stat_smooth(se=T, alpha=0.5) +
  # stat_smooth() fits on the layer data, which holds only the mapped
  # aesthetics (x, y, grouping), so the formula must be written in x and y.
  # The old formula named psychometric, prev_surp, freq, len, ... (not in the
  # layer data) and had `se = F` misplaced inside te().
  geom_smooth(method = "gam", formula = y ~ s(x, bs = "cr", k = 20), se = FALSE) +
  #geom_errorbar(color="black", width=.2, position=position_dodge(width=.9), alpha=0.3) +
  #geom_point(stat="identity", position="dodge", alpha=1, size=3) +
  ylab("Processing Time (ms)") +
  xlab("Surprisal (bits)") +
  ggtitle("Surprisal vs. Reading Time / Gaze Duration") +
  facet_wrap(model ~ corpus, scales = "free", ncol = 3, strip.position = c("right")) +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "bllip-lg-gptbpe" = "#888888",
                                "bllip-md-gptbpe" = "#888888",
                                "bllip-sm-gptbpe" = "#888888",
                                "bllip-xs-gptbpe" = "#888888")) +
  coord_cartesian(xlim = c(0, 21)) +
  theme(axis.text = element_text(size = 10),
        axis.text.y = element_text(size = 10),
        strip.text.x = element_text(size = 10),
        legend.text = element_text(size = 10),
        axis.title = element_text(size = 12),
        legend.position = "right")
#ggsave("../images/cuny2020/surp_corr.png",height=6,width=12)
# Pearson correlation between surprisal and the psychometric measure for one
# group of rows (called once per group by the pipeline that follows).
#
# df: data frame with numeric columns `surprisal` and `psychometric`.
# Returns a one-row data frame with column `cor`. cor.test() stops with an
# error when fewer than three complete (x, y) pairs exist; in that case `cor`
# is NA so one sparse group does not abort the whole grouped computation.
corr_test <- function(df) {
  n_complete <- sum(complete.cases(df$surprisal, df$psychometric))
  if (n_complete < 3) {
    return(data.frame(cor = NA_real_))
  }
  data.frame(cor = unname(cor.test(df$surprisal, df$psychometric)$estimate))
}
# Surprisal / psychometric correlation per (model, training, corpus, seed).
# group_modify() replaces the superseded do() (and the stray `cor =`
# assignment that was inside it) and re-attaches the grouping columns.
all_data %>%
  group_by(model, training, corpus, seed) %>%
  group_modify(~ corr_test(.x)) %>%
  ungroup()
NA
# Raw scatter of surprisal vs. processing time for the vanilla model, faceted
# by corpus (rows) and training set (columns) with free scales.
all_data %>%
  #filter(surprisal < 15, surprisal > 0) %>%
  filter(model == "vanilla") %>%
  ggplot(aes(x = surprisal, y = psychometric)) +
  #stat_smooth(se=T, alpha=0.5) +
  #geom_errorbar(color="black", width=.2, position=position_dodge(width=.9), alpha=0.3) +
  geom_point(alpha = 0.1) +
  ylab("Processing Time (ms)") +
  xlab("Surprisal (bits)") +
  ggtitle("Surprisal vs. Reading Time / Gaze Duration: Vanilla") +
  facet_grid(corpus ~ training, scales = "free")
# Colour scale kept for reference (not applied to the plot above):
# scale_color_manual(values = c("bllip-lg"="#440154FF",
#                               "bllip-md"="#39568CFF",
#                               "bllip-sm"="#1F968BFF",
#                               "bllip-xs"="#73D055FF",
#                               "bllip-lg-gptbpe"="#888888",
#                               "bllip-md-gptbpe"="#888888",
#                               "bllip-sm-gptbpe"="#888888",
#                               "bllip-xs-gptbpe"="#888888"))
# Dundee / vanilla / bllip-lg inspection: tokens with surprisal above 20 but
# psychometric below 300, then the residual table sorted from the largest
# residual down, then the surprisal density of tokens with resid > 150.
all_data %>%
  filter(corpus == "dundee", model == "vanilla", training == "bllip-lg",
         surprisal > 20, psychometric < 300)
full_residuals %>%
  filter(corpus == "dundee", model == "vanilla", training == "bllip-lg") %>%
  arrange(desc(resid)) %>%
  print()
full_residuals %>%
  filter(corpus == "dundee", model == "vanilla", training == "bllip-lg") %>%
  arrange(desc(resid)) %>%
  filter(resid > 150) %>%
  ggplot(aes(x = surprisal)) +
  geom_density()
# Raw scatter of surprisal vs. processing time for the RNNG model, faceted
# by corpus (rows) and training set (columns) with free scales.
all_data %>%
  #filter(surprisal < 15, surprisal > 0) %>%
  filter(model == "rnng") %>%
  ggplot(aes(x = surprisal, y = psychometric)) +
  #stat_smooth(se=T, alpha=0.5) +
  #geom_errorbar(color="black", width=.2, position=position_dodge(width=.9), alpha=0.3) +
  geom_point(alpha = 0.1) +
  ylab("Processing Time (ms)") +
  xlab("Surprisal (bits)") +
  ggtitle("Surprisal vs. Reading Time / Gaze Duration: RNNG") +
  facet_grid(corpus ~ training, scales = "free")
# Dundee / RNNG / bllip-lg inspection: tokens with surprisal above 20 but
# psychometric below 300, then the residual table sorted from the largest
# residual down, then the surprisal density of tokens with resid > 150.
all_data %>%
  filter(corpus == "dundee", model == "rnng", training == "bllip-lg",
         surprisal > 20, psychometric < 300)
full_residuals %>%
  filter(corpus == "dundee", model == "rnng", training == "bllip-lg") %>%
  arrange(desc(resid)) %>%
  print()
full_residuals %>%
  filter(corpus == "dundee", model == "rnng", training == "bllip-lg") %>%
  arrange(desc(resid)) %>%
  filter(resid > 150) %>%
  ggplot(aes(x = surprisal)) +
  geom_density()
# Token-level residual comparison between the 5-gram and vanilla models (both
# trained on bllip-sm). Residuals and covariates are averaged over rows that
# share (corpus, code). summarise() would otherwise leave each result grouped
# by corpus, so ungroup() is called explicitly.
ngram_resids <- full_residuals %>%
  filter(model == "5gram", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(freq = mean(freq), psychometric = mean(psychometric),
            surprisal = mean(surprisal), resid = mean(resid)) %>%
  ungroup()
vanilla_resids <- full_residuals %>%
  filter(model == "vanilla", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(freq = mean(freq), psychometric = mean(psychometric),
            surprisal = mean(surprisal), resid = mean(resid)) %>%
  ungroup()
resids_joined <- ngram_resids %>%
  left_join(vanilla_resids, by = c("corpus", "code"), suffix = c(".ngram", ".vanilla"))
# Residual vs. residual per corpus, with the identity line in red.
resids_joined %>%
  ggplot(aes(x = resid.ngram, y = resid.vanilla)) +
  geom_point() +
  geom_abline(slope = 1, color = "red") +
  facet_grid(~ corpus)
# Distribution of the signed residual difference (ngram - vanilla).
resids_joined %>%
  mutate(resid_diff = resid.ngram - resid.vanilla) %>%
  ggplot(aes(x = resid_diff)) +
  geom_density() +
  facet_grid(~ corpus)
# 5-gram surprisal density, split by whether the 5-gram's absolute residual
# is more than 10 below the vanilla model's (big = TRUE).
resids_joined %>%
  mutate(resid_diff = abs(resid.ngram) - abs(resid.vanilla),
         big = resid_diff < -10) %>%
  ggplot(aes(x = surprisal.ngram, color = big)) +
  geom_density() +
  facet_grid(~ corpus) +
  ggtitle("ngram surprisal of high-improvement tokens (relative to vanilla)")
# Absolute residual difference against the (averaged) frequency.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.ngram - resid.vanilla)) %>%
  ggplot(aes(x = freq.ngram, y = resid_abs_diff)) +
  geom_point(alpha = 0.1) +
  geom_smooth()
# Token-level residual comparison between GPT-2 (bllip-sm-gptbpe) and the
# vanilla model (bllip-sm). Residuals and covariates are averaged over rows
# that share (corpus, code). summarise() would otherwise leave each result
# grouped by corpus, so ungroup() is called explicitly.
gpt_resids <- full_residuals %>%
  filter(model == "gpt2", training == "bllip-sm-gptbpe") %>%
  group_by(corpus, code) %>%
  summarise(freq = mean(freq), psychometric = mean(psychometric),
            surprisal = mean(surprisal), resid = mean(resid)) %>%
  ungroup()
vanilla_resids <- full_residuals %>%
  filter(model == "vanilla", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(freq = mean(freq), psychometric = mean(psychometric),
            surprisal = mean(surprisal), resid = mean(resid)) %>%
  ungroup()
resids_joined <- gpt_resids %>%
  left_join(vanilla_resids, by = c("corpus", "code"), suffix = c(".gpt", ".vanilla"))
# Residual vs. residual per corpus, with the identity line in red.
resids_joined %>%
  ggplot(aes(x = resid.gpt, y = resid.vanilla)) +
  geom_point() +
  geom_abline(slope = 1, color = "red") +
  facet_grid(~ corpus)
# Distribution of the signed residual difference (gpt - vanilla).
resids_joined %>%
  mutate(resid_diff = resid.gpt - resid.vanilla) %>%
  ggplot(aes(x = resid_diff)) +
  geom_density() +
  facet_grid(~ corpus)
# GPT-2 surprisal density, split by whether GPT-2's absolute residual is
# more than 10 below the vanilla model's (big = TRUE).
resids_joined %>%
  mutate(resid_diff = abs(resid.gpt) - abs(resid.vanilla),
         big = resid_diff < -10) %>%
  ggplot(aes(x = surprisal.gpt, color = big)) +
  geom_density() +
  facet_grid(~ corpus) +
  ggtitle("gpt surprisal of high-improvement tokens (relative to vanilla)")
# Absolute residual difference against the (averaged) frequency.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.gpt - resid.vanilla)) %>%
  ggplot(aes(x = freq.gpt, y = resid_abs_diff)) +
  geom_point(alpha = 0.1) +
  geom_smooth()
# Pair each token's residual under the full model with its residual under the
# baseline model (suffixes .full / .baseline). Record each residual's sign
# (1 when positive, 0 otherwise), then switch to absolute residuals so that
# resid_delta = |baseline| - |full| is positive when the full model fits the
# token better.
resid_deltas <- full_residuals %>%
  right_join(baseline_residuals,
             by = c("corpus", "code", "model", "training", "seed"),
             suffix = c(".full", ".baseline")) %>%
  select(resid.baseline, resid.full, code, surprisal.full, psychometric.full,
         model, training, seed, corpus, len.full) %>%
  mutate(resid.baseline.pol = as.numeric(resid.baseline > 0),
         resid.full.pol = as.numeric(resid.full > 0)) %>%
  mutate(resid.baseline = abs(resid.baseline),
         resid.full = abs(resid.full),
         resid_delta = resid.baseline - resid.full,
         training_source = as.factor(str_replace(training, "-gptbpe", "")),
         bpe = str_detect(training, "gptbpe"))
# Tokens whose residual sign differs between the full and baseline models.
r <- resid_deltas %>%
  filter(resid.full.pol != resid.baseline.pol)
# Residual improvement against surprisal, one panel per model x corpus.
resid_deltas %>%
  ggplot(aes(x = surprisal.full, y = resid_delta, color = training)) +
  facet_grid(model ~ corpus) +
  geom_point(alpha = 0.1, size = 0.5)
language_model_data %>%
  filter(model == "gpt2")
# Marginal density of the psychometric measure after z-scoring within each
# corpus. A grey reference line marks 0, and the x-axis title, text and ticks
# are hidden.
resid_deltas %>%
  group_by(corpus) %>%
  mutate(psychometric = scale(psychometric.full)) %>%
  ungroup() %>%
  ggplot(aes(x = psychometric)) +
  theme_bw() +
  geom_density() +
  geom_vline(xintercept = 0, color = "grey") +
  facet_grid(. ~ corpus) +
  #coord_cartesian(xlim = c(-2, 4)) +
  theme(axis.title.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank())
#ggsave("length.png", width = 8, height = 1)
# delta_log_lik against the within-corpus z-scored psychometric measure,
# smoothed per model. Blue line at y = 0, grey line at x = 0.
# NOTE(review): delta_log_lik is defined outside this chunk; presumably the
# per-token log-likelihood gain over a baseline — confirm.
log_lik_deltas %>%
  #resid_deltas %>%
  #filter(resid.full.pol == resid.baseline.pol) %>%
  group_by(corpus) %>%
  mutate(psychometric = scale(psychometric)) %>%
  ungroup() %>%
  #filter(psychometric < 4) %>%
  #filter(len.full <= 10) %>%
  ggplot(aes(x = psychometric, y = delta_log_lik, color = model)) +
  theme_bw() +
  facet_grid(. ~ corpus, scales = "free") +
  #geom_rug(alpha = 0.003, sides = "b") +
  geom_hline(yintercept = 0, color = "blue") +
  geom_vline(xintercept = 0, color = "grey") +
  geom_smooth(se = TRUE, alpha = 0.2) +
  coord_cartesian(ylim = c(-0.1, 0.2), xlim = c(-2, 4)) +
  theme(legend.position = "bottom",
        strip.text.x = element_blank())
ggsave("./resid_psycho.png", height = 4, width = 8)
NA
NA
# Histogram of len.full per corpus (20 bins, visible x range 1-10), saved as
# a short strip figure. The previous version also z-scored the psychometric
# measure within corpus, but the histogram never used that column, so the
# dead step is dropped. The x = 0 reference line lies outside the visible
# range and is kept only to match the other marginal plots.
# NOTE(review): the output filename spells "maringals"; kept as-is in case
# other material refers to that exact path.
resid_deltas %>%
  ggplot(aes(x = len.full)) +
  theme_bw() +
  geom_histogram(bins = 20) +
  geom_vline(xintercept = 0, color = "grey") +
  facet_grid(. ~ corpus) +
  coord_cartesian(xlim = c(1, 10)) +
  theme(axis.title.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank())
ggsave("length_maringals.png", width = 8, height = 1)
# delta_log_lik against len, smoothed per model, with visible x range 1-10.
# The within-corpus z-scored psychometric column is not used by the active
# plot; only the commented-out filter below refers to it.
log_lik_deltas %>%
  #resid_deltas %>%
  #filter(resid.full.pol == resid.baseline.pol) %>%
  group_by(corpus) %>%
  mutate(psychometric = scale(psychometric)) %>%
  ungroup() %>%
  #filter(psychometric < 4) %>%
  #filter(len.full <= 10) %>%
  ggplot(aes(x = len, y = delta_log_lik, color = model)) +
  theme_bw() +
  facet_grid(. ~ corpus, scales = "free") +
  #geom_rug(alpha = 0.003, sides = "b") +
  geom_hline(yintercept = 0, color = "blue") +
  geom_vline(xintercept = 0, color = "grey") +
  geom_smooth(se = TRUE, alpha = 0.2) +
  coord_cartesian(ylim = c(-0.02, 0.06), xlim = c(1, 10)) +
  theme(legend.position = "bottom",
        strip.text.x = element_blank())
ggsave("./resid_length.png", height = 4, width = 8)
# Within-word normalisation. For each (word, corpus, model, training, seed)
# group, z-score the psychometric measure (psychoword) and surprisal
# (norm_surp). as.numeric() stores plain numeric vectors instead of the
# one-column matrices scale() returns, so the later filter(norm_surp > 2) and
# the plots see ordinary columns. ungroup() leaves the result ungrouped.
# Groups whose values are all equal (including single-row groups) get NaN.
word_norm <- log_lik_deltas %>%
  drop_na() %>%
  group_by(word, corpus, model, training, seed) %>%
  mutate(psychoword = as.numeric(scale(psychometric)),
         norm_surp = as.numeric(scale(surprisal))) %>%
  ungroup()
# Density of within-word normalised surprisal per corpus (x range -2 to 5).
word_norm %>%
  ggplot(aes(x = norm_surp)) +
  facet_grid(~ corpus) +
  geom_density() +
  coord_cartesian(xlim = c(-2, 5)) +
  geom_vline(xintercept = 0, color = "grey") +
  theme_bw() +
  theme(axis.title.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank())
ggsave("surp_maringals.png", width = 8, height = 1)
NA
NA
# Within-word normalised surprisal against the within-word normalised
# psychometric measure, smoothed per model and faceted by corpus.
word_norm %>%
  ggplot(aes(x = psychoword, y = norm_surp, color = model)) +
  theme_bw() +
  facet_grid(. ~ corpus, scales = "free") +
  #geom_rug(alpha = 0.003, sides = "b") +
  geom_hline(yintercept = 0, color = "blue") +
  geom_vline(xintercept = 0, color = "grey") +
  geom_smooth(se = TRUE, alpha = 0.2) +
  #coord_cartesian(ylim = c(-0.05, 0.1), xlim = c(-2, 3)) +
  theme(legend.position = "bottom")
#ggsave( "./resid_length.png", height = 4, width = 8)
# delta_log_lik against within-word normalised surprisal, smoothed per model
# and faceted by corpus. Visible window: x from -2 to 5, y from -0.05 to 0.07.
word_norm %>%
  ggplot(aes(x = norm_surp, y = delta_log_lik, color = model)) +
  theme_bw() +
  facet_grid(. ~ corpus, scales = "free") +
  #geom_rug(alpha = 0.003, sides = "b") +
  geom_hline(yintercept = 0, color = "blue") +
  geom_vline(xintercept = 0, color = "grey") +
  geom_smooth(se = TRUE, alpha = 0.2) +
  coord_cartesian(ylim = c(-0.05, 0.07), xlim = c(-2, 5)) +
  theme(legend.position = "bottom")
ggsave("./norm_surp.png", height = 4, width = 8)
# Dundee ablation. Collect the codes of tokens whose within-word normalised
# 5-gram surprisal exceeds 2, drop every Dundee row carrying one of those
# codes (for all models), and write the remaining rows to CSV.
ngram_highsurp <- word_norm %>%
  ungroup() %>%
  filter(corpus == "dundee", norm_surp > 2, model == "5gram") %>%
  pull(code)
z <- word_norm %>%
  ungroup() %>%
  filter(!code %in% ngram_highsurp) %>%
  filter(corpus == "dundee")
write.csv(z, "ngram-ablate.csv")